{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "collapsed_sections": [],
      "toc_visible": true,
      "authorship_tag": "ABX9TyNRcUVF3ZzTw+oK4ortpcH+",
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/ferluht/Mubert-Text-to-Music/blob/main/Mubert_Text_to_Music.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# **Mubert Text to Music ✍ ➡ 🎹🎵🔊**\n",
        "\n",
        "A simple notebook demonstrating prompt-based music generation via [Mubert](https://mubert.com) [API](https://mubert2.docs.apiary.io/)"
      ],
      "metadata": {
        "id": "gHJPhnu7Lg2v"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "GPdDFKWVVnif"
      },
      "outputs": [],
      "source": [
        "#@title **Setup Environment**\n",
        "\n",
        "import subprocess, sys, time\n",
        "\n",
        "print(\"Setting up environment...\")\n",
        "start_time = time.time()\n",
        "\n",
        "# Each entry is the argument vector of one pip invocation; they run in order.\n",
        "all_process = [\n",
        "    ['pip', 'install', 'torch==1.12.1+cu113', 'torchvision==0.13.1+cu113', '--extra-index-url', 'https://download.pytorch.org/whl/cu113'],\n",
        "    ['pip', 'install', '-U', 'sentence-transformers'],\n",
        "    ['pip', 'install', 'httpx'],\n",
        "]\n",
        "for process in all_process:\n",
        "    # Go through the kernel's own interpreter (python -m pip ...) so the\n",
        "    # packages are installed into the environment this notebook imports from.\n",
        "    result = subprocess.run(\n",
        "        [sys.executable, '-m', *process],\n",
        "        stdout=subprocess.PIPE,\n",
        "        stderr=subprocess.STDOUT,\n",
        "    )\n",
        "    if result.returncode != 0:\n",
        "        # Best effort: keep installing the remaining packages, but show what failed\n",
        "        # instead of discarding pip's output.\n",
        "        print(f\"WARNING: `{' '.join(process)}` exited with code {result.returncode}\")\n",
        "        print(result.stdout.decode('utf-8', errors='replace'))\n",
        "\n",
        "end_time = time.time()\n",
        "print(f\"Environment set up in {end_time-start_time:.0f} seconds\")"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "#@title **Define Mubert methods and pre-compute things**\n",
        "\n",
        "import numpy as np\n",
        "from sentence_transformers import SentenceTransformer\n",
        "# Sentence-embedding model; the same model encodes both the tag vocabulary and the prompts.\n",
        "minilm = SentenceTransformer('all-MiniLM-L6-v2')\n",
        "\n",
        "# Comma-separated tag vocabulary. It must stay on ONE physical line: a raw line break\n",
        "# inside the literal is invalid both in the notebook's JSON and in a single-quoted Python string.\n",
        "mubert_tags_string = 'tribal,action,kids,neo-classic,run 130,pumped,jazz / funk,ethnic,dubtechno,reggae,acid jazz,liquidfunk,funk,witch house,tech house,underground,artists,mystical,disco,sensorium,r&b,agender,psychedelic trance / psytrance,peaceful,run 140,piano,run 160,setting,meditation,christmas,ambient,horror,cinematic,electro house,idm,bass,minimal,underscore,drums,glitchy,beautiful,technology,tribal house,country pop,jazz & funk,documentary,space,classical,valentines,chillstep,experimental,trap,new jack swing,drama,post-rock,tense,corporate,neutral,happy,analog,funky,spiritual,sberzvuk special,chill hop,dramatic,catchy,holidays,fitness 90,optimistic,orchestra,acid techno,energizing,romantic,minimal house,breaks,hyper pop,warm up,dreamy,dark,urban,microfunk,dub,nu disco,vogue,keys,hardcore,aggressive,indie,electro funk,beauty,relaxing,trance,pop,hiphop,soft,acoustic,chillrave / ethno-house,deep techno,angry,dance,fun,dubstep,tropical,latin pop,heroic,world music,inspirational,uplifting,atmosphere,art,epic,advertising,chillout,scary,spooky,slow ballad,saxophone,summer,erotic,jazzy,energy 100,kara mar,xmas,atmospheric,indie pop,hip-hop,yoga,reggaeton,lounge,travel,running,folk,chillrave & ethno-house,detective,darkambient,chill,fantasy,minimal techno,special,night,tropical house,downtempo,lullaby,meditative,upbeat,glitch hop,fitness,neurofunk,sexual,indie rock,future pop,jazz,cyberpunk,melancholic,happy hardcore,family / kids,synths,electric guitar,comedy,psychedelic trance & psytrance,edm,psychedelic rock,calm,zen,bells,podcast,melodic house,ethnic percussion,nature,heavy,bassline,indie dance,techno,drumnbass,synth pop,vaporwave,sad,8-bit,chillgressive,deep,orchestral,futuristic,hardtechno,nostalgic,big room,sci-fi,tutorial,joyful,pads,minimal 170,drill,ethnic 108,amusing,sleepy ambient,psychill,italo disco,lofi,house,acoustic guitar,bassline house,rock,k-pop,synthwave,deep house,electronica,gabber,nightlife,sport & fitness,road trip,celebration,electro,disco house,electronic'\n",
        "# Tag names as an array (so they can be fancy-indexed) and their embeddings, computed once.\n",
        "mubert_tags = np.array(mubert_tags_string.split(','))\n",
        "mubert_tags_embeddings = minilm.encode(mubert_tags)\n",
        "\n",
        "from IPython.display import Audio, display\n",
        "import httpx\n",
        "import json\n",
        "\n",
        "def get_track_by_tags(tags, pat, duration, maxit=20, autoplay=False, loop=False):\n",
        "  '''Request a track for the given tags from the Mubert API and play it once it is ready.\n",
        "\n",
        "  Args:\n",
        "    tags: tag names sent as the request's `tags` parameter.\n",
        "    pat: personal access token sent as the request's `pat` parameter.\n",
        "    duration: sent as the request's `duration` parameter (presumably seconds -- confirm with the API docs).\n",
        "    maxit: maximum number of polling attempts, made one second apart.\n",
        "    autoplay: forwarded to the IPython `Audio` widget.\n",
        "    loop: when true the request mode is `loop`, otherwise `track`.\n",
        "\n",
        "  Returns:\n",
        "    The track's download URL once it answers with HTTP 200, or None when it was\n",
        "    still unavailable after `maxit` attempts.\n",
        "\n",
        "  Raises:\n",
        "    AssertionError: when the API response status is not 1; the message is the API's error text.\n",
        "  '''\n",
        "  # Local import: this cell does not import `time` itself, so without it the\n",
        "  # function only works if another cell happened to import it first.\n",
        "  import time\n",
        "\n",
        "  mode = \"loop\" if loop else \"track\"\n",
        "  r = httpx.post('https://api-b2b.mubert.com/v2/RecordTrackTTM', \n",
        "      json={\n",
        "          \"method\":\"RecordTrackTTM\",\n",
        "          \"params\": {\n",
        "              \"pat\": pat, \n",
        "              \"duration\": duration,\n",
        "              \"tags\": tags,\n",
        "              \"mode\": mode\n",
        "          }\n",
        "      })\n",
        "\n",
        "  rdata = json.loads(r.text)\n",
        "  assert rdata['status'] == 1, rdata['error']['text']\n",
        "  trackurl = rdata['data']['tasks'][0]['download_link']\n",
        "\n",
        "  # Poll the download link until it starts answering with HTTP 200.\n",
        "  print('Generating track ', end='')\n",
        "  for _ in range(maxit):\n",
        "      r = httpx.get(trackurl)\n",
        "      if r.status_code == 200:\n",
        "          display(Audio(trackurl, autoplay=autoplay))\n",
        "          return trackurl\n",
        "      time.sleep(1)\n",
        "      print('.', end='')\n",
        "\n",
        "  # Polling budget exhausted: say so instead of ending silently with no audio.\n",
        "  print(f' not ready after {maxit} attempts; retry with a larger maxit or open {trackurl} later')\n",
        "  return None\n",
        "\n",
        "def find_similar(em, embeddings, method='cosine'):\n",
        "    scores = []\n",
        "    for ref in embeddings:\n",
        "        if method == 'cosine': \n",
        "            scores.append(1 - np.dot(ref, em)/(np.linalg.norm(ref)*np.linalg.norm(em)))\n",
        "        if method == 'norm': \n",
        "            scores.append(np.linalg.norm(ref - em))\n",
        "    return np.array(scores), np.argsort(scores)\n",
        "\n",
        "def get_tags_for_prompts(prompts, top_n=3, debug=False):\n",
        "    '''Pick the `top_n` Mubert tags closest to each prompt.\n",
        "\n",
        "    Args:\n",
        "        prompts: a list of prompt strings; a single string is treated as a\n",
        "            one-element list.\n",
        "        top_n: number of tags to keep per prompt.\n",
        "        debug: when true, print each prompt with its tags and similarity scores.\n",
        "\n",
        "    Returns:\n",
        "        A list of (prompt, tags) tuples in input order, where `tags` is a list\n",
        "        of plain Python strings ordered from most to least similar.\n",
        "    '''\n",
        "    if isinstance(prompts, str):\n",
        "        # A bare string would be encoded as ONE vector and then iterated element-wise.\n",
        "        prompts = [prompts]\n",
        "    prompts_embeddings = minilm.encode(prompts)\n",
        "    ret = []\n",
        "    for prompt, pe in zip(prompts, prompts_embeddings):\n",
        "        scores, idxs = find_similar(pe, mubert_tags_embeddings)\n",
        "        best = idxs[:top_n]\n",
        "        # tolist() yields built-in str objects rather than numpy string scalars,\n",
        "        # so the tags print cleanly when callers show the list.\n",
        "        top_tags = mubert_tags[best].tolist()\n",
        "        if debug:\n",
        "            # find_similar returns cosine distances; turn them back into similarities.\n",
        "            top_prob = 1 - scores[best]\n",
        "            print(f\"Prompt: {prompt}\\nTags: {', '.join(top_tags)}\\nScores: {top_prob}\\n\\n\\n\")\n",
        "        ret.append((prompt, top_tags))\n",
        "    return ret"
      ],
      "metadata": {
        "cellView": "form",
        "id": "yW-3aTNYvKM_"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "#@markdown **Get personal access token in Mubert and define API methods**\n",
        "email = \"your e-mail here\" #@param {type:\"string\"}\n",
        "\n",
        "# Exchange the e-mail for a personal access token (pat) used by the track requests below.\n",
        "# NOTE(review): license and token are hard-coded in the notebook; presumably shared demo\n",
        "# credentials -- confirm before reusing them elsewhere.\n",
        "r = httpx.post('https://api-b2b.mubert.com/v2/GetServiceAccess', \n",
        "    json={\n",
        "        \"method\":\"GetServiceAccess\",\n",
        "        \"params\": {\n",
        "            \"email\": email,\n",
        "            \"license\":\"ttmmubertlicense#f0acYBenRcfeFpNT4wpYGaTQIyDI4mJGv5MfIhBFz97NXDwDNFHmMRsBSzmGsJwbTpP1A6i07AXcIeAHo5\",\n",
        "            \"token\":\"4951f6428e83172a4f39de05d5b3ab10d58560b8\",\n",
        "            \"mode\": \"loop\"\n",
        "        }\n",
        "    })\n",
        "\n",
        "rdata = json.loads(r.text)\n",
        "# Include the API's own error text (when it sends one) instead of only guessing at the cause.\n",
        "api_error = rdata.get('error')\n",
        "api_error_text = api_error.get('text', '') if isinstance(api_error, dict) else ''\n",
        "assert rdata['status'] == 1, f\"probably incorrect e-mail {api_error_text}\".strip()\n",
        "pat = rdata['data']['pat']\n",
        "# Show only a short prefix: cell output is saved with the notebook and easily shared.\n",
        "print(f'Got token: {pat[:4]}... (rest hidden)')"
      ],
      "metadata": {
        "cellView": "form",
        "id": "a4ACdvWLRJ5U"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "#@title **Generate some music 🎵**\n",
        "\n",
        "prompt = 'vladimir lenin smoking weed with bob marley' #@param {type:\"string\"}\n",
        "duration = 30 #@param {type:\"number\"}\n",
        "loop = False #@param {type:\"boolean\"}\n",
        "\n",
        "def generate_track_by_prompt(prompt, duration, loop=False):\n",
        "  '''Map one prompt to Mubert tags, then request and autoplay a track for them.\n",
        "\n",
        "  Errors raised while requesting the track are printed instead of propagated.\n",
        "  '''\n",
        "  matches = get_tags_for_prompts([prompt])\n",
        "  tags = matches[0][1]  # each match is a (prompt, tags) pair\n",
        "  try:\n",
        "    get_track_by_tags(tags, pat, duration, autoplay=True, loop=loop)\n",
        "  except Exception as err:\n",
        "    print(str(err))\n",
        "  print('\\n')\n",
        "\n",
        "generate_track_by_prompt(prompt, duration, loop)\n"
      ],
      "metadata": {
        "cellView": "form",
        "id": "hTf7sZcfbI0K"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### **Batch generation 🎶**"
      ],
      "metadata": {
        "id": "wSKTfub-bitp"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "duration = 60\n",
        "\n",
        "prompts = [\n",
        "    'kind beaver guards life tree, stan lee, epic',\n",
        "    'astronaut riding a horse',\n",
        "    'winnie the pooh cooking methamphetamine',\n",
        "    'vladimir lenin smoking weed with bob marley',\n",
        "    'soviet retrofuturism',\n",
        "    'two wasted friends high on weed are trying to navigate their way to their hostel in a big city, night, trippy',\n",
        "    'an elephant levitating on a gas balloon',\n",
        "    'calm music',\n",
        "    'a refrigerator floating in a pond'\n",
        "]\n",
        "\n",
        "# Resolve every prompt to its tags first, then request one track per prompt.\n",
        "tags = get_tags_for_prompts(prompts)\n",
        "\n",
        "for prompt_text, prompt_tags in tags:\n",
        "  print(f'Prompt: {prompt_text}\\nTags: {prompt_tags}')\n",
        "  try:\n",
        "    get_track_by_tags(prompt_tags, pat, duration, autoplay=False)\n",
        "  except Exception as err:\n",
        "    # Report the failure and carry on with the remaining prompts.\n",
        "    print(str(err))\n",
        "  print('\\n')"
      ],
      "metadata": {
        "id": "BzrhcwIHXlg0"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "JieLk6kjZFai"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}